{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "714c529f",
   "metadata": {},
   "source": [
    "# RemoveWeightLayoutRewriteBlock"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "623f66bd",
   "metadata": {},
   "source": [
    "\n",
    "该测试文件主要用于验证 TVM TIR 中的 `RemoveWeightLayoutRewriteBlock` 转换功能。这个转换的核心作用是优化深度学习模型中的权重布局处理，具体来说：\n",
    "\n",
    "1. **优化背景**：在深度学习编译中，为了提高计算效率，通常需要对权重张量进行特定的布局转换（如分块、重排等）。原始代码中通常会在计算前先执行一个专门的布局转换块。\n",
    "\n",
    "2. **转换目的**：`RemoveWeightLayoutRewriteBlock` 转换的目的是消除这个显式的布局转换块，将权重布局转换操作集成到计算过程中，从而减少内存访问和数据传输开销。\n",
    "\n",
    "3. **主要优化点**：\n",
    "   - 移除显式的权重布局转换块，消除临时缓冲区的分配和填充\n",
    "   - 修改计算块，使其直接使用重排后的权重格式\n",
    "   - 保留布局重写块的框架但将其内容替换为空操作\n",
    "   - 调整函数签名，使输入权重直接使用重排后的形状\n",
    "\n",
    "4. **测试方法**：通过定义 `before` 和 `after` 两个 TIR 原语函数，分别表示转换前后的函数形式，然后使用 `_check` 函数验证应用转换后是否与预期结果一致。测试场景是一个 16×16 的矩阵乘法，重点验证权重重排从 (16,16) 到 (16,4,4) 的优化。\n",
    "\n",
    "这种优化对于提升深度学习模型在硬件上的执行效率非常重要，特别是在专用加速器或需要特定数据布局以最大化计算效率的场景中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cf6e3ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的系统库\n",
    "import sys\n",
    "\n",
    "# 导入 TVM 相关模块\n",
    "import tvm\n",
    "from tvm.ir.module import IRModule\n",
    "from tvm.script import tir as T\n",
    "from tvm.tir.function import PrimFunc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54ffb7f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义辅助函数，用于检查转换前后的 IR 模块是否结构等价\n",
    "def _check(before, expect):\n",
    "    # 如果输入是 PrimFunc，将其转换为 IRModule 格式\n",
    "    if isinstance(before, PrimFunc):\n",
    "        before = IRModule({\"main\": before.with_attr(\"global_symbol\", \"main\")})\n",
    "    if isinstance(expect, PrimFunc):\n",
    "        expect = IRModule({\"main\": expect.with_attr(\"global_symbol\", \"main\")})\n",
    "\n",
    "    # 应用 RemoveWeightLayoutRewriteBlock 转换\n",
    "    mod = tvm.tir.transform.RemoveWeightLayoutRewriteBlock()(before)\n",
    "    # 验证转换后的模块与预期模块是否结构等价\n",
    "    tvm.ir.assert_structural_equal(mod, expect)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f04a8799",
   "metadata": {},
   "source": [
    "## 测试矩阵乘法中的权重布局重写块移除功能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e63f9d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义转换前的 TIR 原语函数\n",
    "@T.prim_func\n",
    "def before(\n",
    "    A: T.Buffer((16, 16), \"float32\"),  # 矩阵 A，形状为 (16, 16)，数据类型为 float32\n",
    "    B: T.Buffer((16, 16), \"float32\"),  # 矩阵 B，形状为 (16, 16)，数据类型为 float32\n",
    "    C: T.Buffer((16, 16), \"float32\"),  # 结果矩阵 C，形状为 (16, 16)，数据类型为 float32\n",
    ") -> None:\n",
    "    # 函数属性，指定 B 是布局自由缓冲区（layout free buffer）\n",
    "    T.func_attr({\"layout_free_buffers\": [1]})\n",
    "    # 分配临时缓冲区 B_，用于存储重排后的权重\n",
    "    B_ = T.alloc_buffer([16, 4, 4], dtype=\"float32\")\n",
    "    \n",
    "    # 布局重写块：将矩阵 B 从 (16,16) 重排为 (16,4,4) 格式\n",
    "    for i0_o, i1_o in T.grid(16, 16):\n",
    "        with T.block(\"layout_rewrite\"):\n",
    "            i0, i1 = T.axis.remap(\"SS\", [i0_o, i1_o])\n",
    "            T.reads(B[i0, i1])\n",
    "            T.writes(B_[i1, i0 // 4, i0 % 4])\n",
    "            # 标记该块为布局重写预处理块\n",
    "            T.block_attr({\"meta_schedule.layout_rewrite_preproc\": True})\n",
    "            # 执行布局转换操作：将 B[i0, i1] 重排为 B_[i1, i0//4, i0%4]\n",
    "            B_[i1, i0 // 4, i0 % 4] = B[i0, i1]\n",
    "    \n",
    "    # 矩阵乘法块：使用重排后的权重缓冲区 B_ 进行计算\n",
    "    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):\n",
    "        with T.block(\"matmul\"):\n",
    "            vi = T.axis.spatial(16, i0 * 4 + i1)  # 行索引，使用分块索引重构\n",
    "            vj = T.axis.spatial(16, j)             # 列索引\n",
    "            vk = T.axis.reduce(16, k0 * 4 + k1)    # 归约索引，使用分块索引重构\n",
    "            T.reads(A[vi, vk], B_[vj, vk // 4, vk % 4])  # 声明读取的内存位置\n",
    "            T.writes(C[vi, vj])                          # 声明写入的内存位置\n",
    "            with T.init():\n",
    "                C[vi, vj] = T.float32(0)                 # 初始化结果为0\n",
    "            # 执行矩阵乘法的累加操作，注意这里使用了重排后的权重 B_ 的索引方式\n",
    "            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_[vj, vk // 4, vk % 4]\n",
    "\n",
    "# 定义应用 RemoveWeightLayoutRewriteBlock 转换后的预期 TIR 原语函数\n",
    "@T.prim_func\n",
    "def after(\n",
    "    A: T.Buffer((16, 16), \"float32\"),      # 矩阵 A，保持不变\n",
    "    B: T.Buffer((16, 4, 4), \"float32\"),    # 注意：矩阵 B 现在直接使用重排后的形状 (16,4,4)\n",
    "    C: T.Buffer((16, 16), \"float32\"),      # 结果矩阵 C，保持不变\n",
    ") -> None:\n",
    "    # 保留布局自由缓冲区属性\n",
    "    T.func_attr({\"layout_free_buffers\": [1]})\n",
    "    \n",
    "    # 布局重写块被保留但内容被修改：\n",
    "    # 1. 移除了对原始 B 矩阵的读取\n",
    "    # 2. 移除了对临时缓冲区 B_ 的写入\n",
    "    # 3. 替换为一个空操作 T.evaluate(0)\n",
    "    for i0_o, i1_o in T.grid(16, 16):\n",
    "        with T.block(\"layout_rewrite\"):\n",
    "            i0, i1 = T.axis.remap(\"SS\", [i0_o, i1_o])\n",
    "            T.reads()\n",
    "            T.writes()\n",
    "            T.block_attr({\"meta_schedule.layout_rewrite_preproc\": True})\n",
    "            T.evaluate(0)  # 空操作\n",
    "    \n",
    "    # 矩阵乘法块：\n",
    "    # 关键变化是现在直接使用输入的 B 矩阵（已重排为 (16,4,4) 形状），\n",
    "    # 而不是之前的临时缓冲区 B_\n",
    "    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):\n",
    "        with T.block(\"matmul\"):\n",
    "            vi = T.axis.spatial(16, i0 * 4 + i1)\n",
    "            vj = T.axis.spatial(16, j)\n",
    "            vk = T.axis.reduce(16, k0 * 4 + k1)\n",
    "            # 关键变化：这里直接使用 B[vj, vk//4, vk%4] 而不是 B_\n",
    "            T.reads(A[vi, vk], B[vj, vk // 4, vk % 4])\n",
    "            T.writes(C[vi, vj])\n",
    "            with T.init():\n",
    "                C[vi, vj] = T.float32(0)\n",
    "            # 计算部分保持相同的索引方式，但现在直接使用重排后的 B 矩阵\n",
    "            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk // 4, vk % 4]\n",
    "\n",
    "# 运行检查，验证 RemoveWeightLayoutRewriteBlock 转换是否按预期工作\n",
    "_check(before, after)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py313",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
